SigmoidCrossEntropyWithLogitsGrad

计算 Sigmoid Cross Entropy With Logits 的梯度。 该算子以 logits 和 labels 为输入,输出对 logits 的梯度值。

数学表达式为:

\[ \begin{align}\begin{aligned}\sigma(x) = \frac{1}{1 + e^{-x}}\\\quad\\dst_i = \sigma(x_i) - y_i\end{aligned}\end{align} \]
其中:
  • \(x_i\) 表示第 i 个 logit(Input0)

  • \(y_i\) 表示第 i 个标签(Input1)

为提高数值稳定性,计算中对正负 logits 采用不同形式:

\[\begin{split}\sigma(x) = \begin{cases} \dfrac{1}{1 + e^{-x}}, & x > 0 \\ \dfrac{e^{x}}{1 + e^{x}}, & x \le 0 \end{cases}\end{split}\]
输入:
  • Input0 - logits 输入数据地址。

  • Input1 - 标签(labels)数据地址。

  • length - 计算长度。

  • core_mask - 核掩码(仅适用于共享存储版本)。

输出:
  • output - 梯度计算结果地址。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持 fp32 类型

  • MT7004 支持 fp16fp32 类型

  • 输入 logits 与 labels 必须具有相同长度

  • 输出为对 logits 的梯度,不包含对 labels 的梯度

共享存储版本:

void fp_sigmoidcrossentropywithlogitsgrad_s(float *Input0, float *Input1, float *output, int length, int core_mask)
void hp_sigmoidcrossentropywithlogitsgrad_s(half *Input0, half *Input1, half *output, int length, int core_mask)

C调用示例:

 1// FT78NE 示例
 2#include <stdio.h>
 3#include <sigmoid_cross_entropy.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *logits = (float *)0xA0000000;   // DDR 空间
 7    float *labels = (float *)0xA0100000;
 8    float *output = (float *)0xC0000000;
 9
10    int length = 1024;
11    int core_mask = 0xff;
12
13    fp_sigmoidcrossentropywithlogitsgrad_s(logits, labels, output, length, core_mask);
14
15    return 0;
16}

私有存储版本:

void fp_sigmoidcrossentropywithlogitsgrad_p(float *Input0, float *Input1, float *output, int length)
void hp_sigmoidcrossentropywithlogitsgrad_p(half *Input0, half *Input1, half *output, int length)

C调用示例:

 1// FT78NE 示例(私有存储)
 2#include <stdio.h>
 3#include <sigmoid_cross_entropy.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *logits = (float *)0x10810000;   // L2 空间
 7    float *labels = (float *)0x10820000;
 8    float *output = (float *)0x10830000;
 9
10    int length = 1024;
11    fp_sigmoidcrossentropywithlogitsgrad_p( logits, labels, output, length);
12
13    return 0;
14}